package db

import (
	"fmt"
	"strings"
	"time"

	version "github.com/hashicorp/go-version"
	rpmver "github.com/knqyf263/go-rpm-version"
	log "github.com/vulsio/go-cve-dictionary/log"
	"github.com/vulsio/go-cve-dictionary/models"
	"golang.org/x/xerrors"

	"github.com/knqyf263/go-cpe/common"
	"github.com/knqyf263/go-cpe/matching"
	"github.com/knqyf263/go-cpe/naming"
)

// DB is the interface implemented by the database drivers handed out by
// newDB (RDBDriver for SQLite3/MySQL/PostgreSQL, RedisDriver for Redis).
//
// NewDB calls OpenDB, IsGoCVEDictModelV1 and MigrateDB, in that order,
// before returning a driver to its caller.
type DB interface {
	// Name returns the driver's name (newDB sets it to the dialect name).
	Name() string
	// OpenDB opens the database at dbPath using the dialect dbType.
	OpenDB(dbType, dbPath string, debugSQL bool, option Option) error
	// CloseDB releases what OpenDB acquired.
	CloseDB() error
	// MigrateDB runs the schema migration.
	MigrateDB() error

	// IsGoCVEDictModelV1 reports whether the database still uses the v1
	// schema, which NewDB rejects as incompatible.
	IsGoCVEDictModelV1() (bool, error)
	// GetFetchMeta reads the stored fetch metadata.
	GetFetchMeta() (*models.FetchMeta, error)
	// UpsertFetchMeta inserts or updates the fetch metadata.
	UpsertFetchMeta(*models.FetchMeta) error

	// Get looks up a single CVE (presumably by CVE ID — confirm in drivers).
	Get(cveID string) (*models.CveDetail, error)
	// GetMulti looks up several CVEs; the map is presumably keyed by CVE ID.
	GetMulti(cveIDs []string) (map[string]models.CveDetail, error)
	// GetCveIDsByCpeURI returns two ID lists for a CPE URI
	// (NOTE(review): likely NVD and JVN IDs respectively — confirm).
	GetCveIDsByCpeURI(cpeURI string) ([]string, []string, error)
	// GetByCpeURI returns the CVE details related to a CPE URI.
	GetByCpeURI(cpeURI string) ([]models.CveDetail, error)
	// InsertJvn / InsertNvd load JVN / NVD data; the meaning of the string
	// slice argument is defined by the drivers (TODO confirm).
	InsertJvn([]string) error
	InsertNvd([]string) error
	// CountNvd / CountJvn return the number of stored NVD / JVN entries.
	CountNvd() (int, error)
	CountJvn() (int, error)
}

// Option carries driver-specific settings; NewDB forwards it unchanged to
// the driver's OpenDB.
type Option struct {
	// RedisTimeout is presumably consumed only by the Redis driver
	// (NOTE(review): confirm which operations it bounds).
	RedisTimeout time.Duration
}

// NewDB returns a ready-to-use DB accessor for dbType.
//
// It builds the driver, opens the database at dbPath, rejects databases
// that still use the go-cve-dictionary v1 schema, and runs MigrateDB.
//
// On success the caller owns the returned driver and is responsible for
// calling CloseDB. On failure the returned driver is nil; if the database
// had already been opened, it is closed before NewDB returns so that the
// connection is not leaked.
func NewDB(dbType, dbPath string, debugSQL bool, option Option) (driver DB, err error) {
	if driver, err = newDB(dbType); err != nil {
		// Returned (not also logged): the caller decides how to report it.
		return nil, xerrors.Errorf("Failed to new db. err: %w", err)
	}

	if err := driver.OpenDB(dbType, dbPath, debugSQL, option); err != nil {
		return nil, xerrors.Errorf("Failed to open db. err: %w", err)
	}

	// From here on the database is open. Callers are expected to close the
	// driver only when NewDB succeeds, so every failure below closes it.
	closeAndFail := func(cause error) (DB, error) {
		if cerr := driver.CloseDB(); cerr != nil {
			// cause explains why NewDB failed; the close error is secondary.
			log.Debugf("Failed to close db. err: %s", cerr)
		}
		return nil, cause
	}

	isV1, err := driver.IsGoCVEDictModelV1()
	if err != nil {
		return closeAndFail(xerrors.Errorf("Failed to IsGoCVEDictModelV1. err: %w", err))
	}
	if isV1 {
		return closeAndFail(xerrors.New("Failed to NewDB. Since SchemaVersion is incompatible, delete Database and fetch again."))
	}

	if err := driver.MigrateDB(); err != nil {
		return closeAndFail(xerrors.Errorf("Failed to migrate db. err: %w", err))
	}

	return driver, nil
}

// newDB maps a dialect name to a new, not yet opened driver. SQLite3,
// MySQL and PostgreSQL share RDBDriver; Redis uses RedisDriver. Any other
// dialect is rejected.
func newDB(dbType string) (DB, error) {
	var driver DB
	switch dbType {
	case dialectSqlite3, dialectMysql, dialectPostgreSQL:
		driver = &RDBDriver{name: dbType}
	case dialectRedis:
		driver = &RedisDriver{name: dbType}
	default:
		return nil, fmt.Errorf("Invalid database dialect: %s", dbType)
	}
	return driver, nil
}

// IndexChunk is a half-open index range [From, To) into a slice, as
// produced by chunkSlice.
type IndexChunk struct {
	From, To int
}

// chunkSlice splits the index range [0, length) into consecutive half-open
// chunks of at most chunkSize indices. It streams them, in order, on the
// returned channel and closes the channel after the last chunk.
//
// A non-positive chunkSize means "no chunking": a single chunk covering
// [0, length) is emitted. A non-positive length yields no chunks.
//
// The producer goroutine exits only once every chunk has been received, so
// callers must drain the channel.
func chunkSlice(length int, chunkSize int) <-chan IndexChunk {
	ch := make(chan IndexChunk)

	size := chunkSize
	if size <= 0 {
		// Without this guard a zero step would never advance the loop.
		size = length
	}

	go func() {
		defer close(ch)

		for from := 0; from < length; {
			to := length
			// Compare against the remaining span instead of computing
			// from+size, which can overflow for very large chunk sizes.
			if length-from > size {
				to = from + size
			}
			ch <- IndexChunk{From: from, To: to}
			from = to
		}
	}()

	return ch
}

// parseCpeURI unbinds a CPE 2.2 URI into a well-formed name and returns a
// CpeBase carrying its URI, formatted-string and WFN renderings along with
// each individual WFN attribute. The unbind error is returned unwrapped.
func parseCpeURI(cpe22uri string) (*models.CpeBase, error) {
	wfn, err := naming.UnbindURI(cpe22uri)
	if err != nil {
		return nil, err
	}

	attrs := models.CpeWFN{
		Part:            wfn.GetString(common.AttributePart),
		Vendor:          wfn.GetString(common.AttributeVendor),
		Product:         wfn.GetString(common.AttributeProduct),
		Version:         wfn.GetString(common.AttributeVersion),
		Update:          wfn.GetString(common.AttributeUpdate),
		Edition:         wfn.GetString(common.AttributeEdition),
		Language:        wfn.GetString(common.AttributeLanguage),
		SoftwareEdition: wfn.GetString(common.AttributeSwEdition),
		TargetSW:        wfn.GetString(common.AttributeTargetSw),
		TargetHW:        wfn.GetString(common.AttributeTargetHw),
		Other:           wfn.GetString(common.AttributeOther),
	}

	base := &models.CpeBase{
		URI:             naming.BindToURI(wfn),
		FormattedString: naming.BindToFS(wfn),
		WellFormedName:  wfn.String(),
		CpeWFN:          attrs,
	}
	return base, nil
}

// makeVersionConstraint turns the NVD version-range fields of cpeInNvd
// into a go-version constraint string such as ">= 1.0, < 2.0".
//
// Bounds are emitted in the order StartIncluding, StartExcluding,
// EndIncluding, EndExcluding; empty fields are skipped. An empty string is
// returned when no bound is set. The first bound that is not valid semver
// aborts with the parse error.
func makeVersionConstraint(cpeInNvd models.CpeBase) (string, error) {
	bounds := []struct {
		op  string
		ver string
	}{
		{">= ", cpeInNvd.VersionStartIncluding},
		{"> ", cpeInNvd.VersionStartExcluding},
		{"<= ", cpeInNvd.VersionEndIncluding},
		{"< ", cpeInNvd.VersionEndExcluding},
	}

	parts := make([]string, 0, len(bounds))
	for _, b := range bounds {
		if b.ver == "" {
			continue
		}
		if _, err := version.NewSemver(b.ver); err != nil {
			log.Debugf("Failed to parse the semver: %s, err: %s", b.ver, err)
			return "", err
		}
		parts = append(parts, b.op+b.ver)
	}
	return strings.Join(parts, ", "), nil
}

// isSamePartVendorProduct reports whether the CPE URIs cpeA and cpeB agree
// on their part, vendor and product attributes. Other attributes, including
// the version, are ignored. An error is returned if either URI fails to
// unbind.
func isSamePartVendorProduct(cpeA, cpeB string) (bool, error) {
	wfnA, err := naming.UnbindURI(cpeA)
	if err != nil {
		return false, xerrors.Errorf("Failed to unbind. CPE: %s. err: %w", cpeA, err)
	}
	wfnB, err := naming.UnbindURI(cpeB)
	if err != nil {
		return false, xerrors.Errorf("Failed to unbind. CPE: %s. err: %w", cpeB, err)
	}

	same := wfnA.Get(common.AttributePart) == wfnB.Get(common.AttributePart) &&
		wfnA.Get(common.AttributeVendor) == wfnB.Get(common.AttributeVendor) &&
		wfnA.Get(common.AttributeProduct) == wfnB.Get(common.AttributeProduct)
	return same, nil
}

// match compares the user-specified CPE URI specifiedURI with a single CPE
// entry taken from NVD and reports how, if at all, they match.
//
// Entries whose part, vendor or product differ never match. Otherwise at most
// one of the boolean results is true:
//   - isVendorProductMatch: the specified version is NA or ANY, and the
//     remaining attributes (version ignored) are a superset or subset of the
//     NVD entry's.
//   - isExactVerMatch: one of four conditions holds.
//     (1) The two CPEs are equal.
//     (2) The NVD entry's version is NA.
//     (3) The specified version satisfies the NVD semver range.
//     (4) Both versions are identical.
//     In cases (3) and (4) the remaining attributes must also be a superset
//     or subset of each other.
//   - isRoughVerMatch: the semver check returned an error, and the specified
//     version lies inside the NVD range when compared as RPM versions. The
//     same superset/subset requirement applies.
//
// err is non-nil only when a URI cannot be unbound or the NVD version cannot
// be reset. A semver error from matchSemver is not returned; it only
// triggers the RPM fallback.
func match(specifiedURI string, cpeInNvd models.CpeBase) (isExactVerMatch, isRoughVerMatch, isVendorProductMatch bool, err error) {
	specified, err := naming.UnbindURI(specifiedURI)
	if err != nil {
		return false, false, false, xerrors.Errorf("Failed to unbind. CPE: %s. err: %w", specifiedURI, err)
	}

	cpeInNvdWfn, err := naming.UnbindURI(cpeInNvd.URI)
	if err != nil {
		return false, false, false, xerrors.Errorf("Failed to unbind. CPE: %s. err: %w", cpeInNvd.URI, err)
	}

	// Different part/vendor/product: unrelated products, never a match.
	if cpeInNvdWfn.Get(common.AttributePart) != specified.Get(common.AttributePart) ||
		cpeInNvdWfn.Get(common.AttributeVendor) != specified.Get(common.AttributeVendor) ||
		cpeInNvdWfn.Get(common.AttributeProduct) != specified.Get(common.AttributeProduct) {
		return false, false, false, nil
	}

	specifiedVer := specified.GetString(common.AttributeVersion)
	switch specifiedVer {
	case "NA", "ANY":
		// No concrete version was specified, so versions cannot be compared;
		// only the remaining attributes decide a vendor/product match.
		if err := cpeInNvdWfn.Set(common.AttributeVersion, nil); err != nil {
			return false, false, false, err
		}
		return false, false, isSuperORSubset(cpeInNvdWfn, specified), nil
	}

	if matching.IsEqual(specified, cpeInNvdWfn) {
		log.Debugf("%s equals %s", specified.String(), cpeInNvd.URI)
		return true, false, false, nil
	}

	// An NVD entry with version NA is treated as matching any specified version.
	if cpeInNvdWfn.GetString(common.AttributeVersion) == "NA" {
		log.Debugf("%s matches %s", specified.String(), cpeInNvd.URI)
		return true, false, false, nil
	}

	ok, err := matchSemver(specifiedVer, cpeInNvd)
	if err != nil {
		// Either the NVD version range or the specified version is not valid
		// semver, so retry the range check assuming RPM-style versions.
		// False positives occur for versions that do not follow RPM comparison rules.
		// If this fallback does not match either, ok is false and control
		// falls through to the fixed-version comparison below.
		if ok := matchRpmVer(specifiedVer, cpeInNvd); ok {
			return false, isSuperORSubset(cpeInNvdWfn, specified), false, nil
		}
	}
	if ok {
		return isSuperORSubset(cpeInNvdWfn, specified), false, false, nil
	}

	// Fall back to comparing the version as a fixed value rather than a range.
	// The remaining attributes must then be a superset or subset of each other.
	//
	// Returns true in this case:
	// - config.toml:      "cpe:/a:apache:cordova:5.1.1::~~~iphone_os~~",
	// - AffectedCPEInNVD: "cpe:/a:apache:cordova:5.1.1",
	//
	// Returns false in this case, because target_sw differs:
	// - config.toml:      "cpe:/a:apache:cordova:5.1.1::~~~iphone_os~~",
	// - AffectedCPEInNVD: "cpe:/a:apache:cordova:5.1.1::~~~android~~",
	if specified.GetString(common.AttributeVersion) != cpeInNvdWfn.GetString(common.AttributeVersion) {
		return false, false, false, nil
	}
	return isSuperORSubset(cpeInNvdWfn, specified), false, false, nil
}

// matchSemver reports whether specifiedVer satisfies the semver range
// declared on cpeInNvd by its VersionStart*/VersionEnd* fields.
//
// Backslash escapes in specifiedVer are removed before parsing. The result
// is (false, nil) when cpeInNvd declares no range at all. A non-nil error is
// returned when a range bound or specifiedVer is not valid semver; match
// treats that error as the cue to retry with RPM version comparison.
func matchSemver(specifiedVer string, cpeInNvd models.CpeBase) (ok bool, err error) {
	constraintStr, err := makeVersionConstraint(cpeInNvd)
	switch {
	case err != nil:
		return false, err
	case constraintStr == "":
		return false, nil
	}

	constraints, err := version.NewConstraint(constraintStr)
	if err != nil {
		return false, err
	}

	unescaped := trimBSlash(specifiedVer)
	v, err := version.NewSemver(unescaped)
	if err != nil {
		log.Debugf("Failed to parse the semver: %s, err: %s", unescaped, err)
		return false, err
	}

	ok = constraints.Check(v)
	if ok {
		log.Debugf("%s satisfies version constraints %s", unescaped, constraintStr)
	}
	return ok, nil
}

// matchRpmVer reports whether specifiedVer lies inside the version range
// declared on cpeInNvd, comparing both as RPM versions.
//
// Unset bounds are ignored, so a CpeBase with no range at all matches every
// version. match uses this as a fallback when the semver comparison fails.
// False positives are possible for versions that do not follow RPM ordering.
func matchRpmVer(specifiedVer string, cpeInNvd models.CpeBase) bool {
	// The specified version comes from a WFN attribute, whose punctuation is
	// backslash-escaped (e.g. `1\:2\.3\-4`). Strip the escapes, as matchSemver
	// does, so separators such as `:` (epoch) and `-` (release) reach the RPM
	// parser unescaped.
	specified := rpmver.NewVersion(trimBSlash(specifiedVer))

	// verStart <= specified
	if cpeInNvd.VersionStartIncluding != "" {
		ver := rpmver.NewVersion(cpeInNvd.VersionStartIncluding)
		if !specified.Equal(ver) && specified.LessThan(ver) {
			return false
		}
	}
	// verStart < specified
	if cpeInNvd.VersionStartExcluding != "" {
		ver := rpmver.NewVersion(cpeInNvd.VersionStartExcluding)
		if specified.Equal(ver) || specified.LessThan(ver) {
			return false
		}
	}
	// specified <= verEnd
	if cpeInNvd.VersionEndIncluding != "" {
		ver := rpmver.NewVersion(cpeInNvd.VersionEndIncluding)
		if !specified.Equal(ver) && ver.LessThan(specified) {
			return false
		}
	}
	// specified < verEnd
	if cpeInNvd.VersionEndExcluding != "" {
		ver := rpmver.NewVersion(cpeInNvd.VersionEndExcluding)
		if specified.Equal(ver) || ver.LessThan(specified) {
			return false
		}
	}
	return true
}

// matchCpe reports whether the CPE URI uri is affected by cve.
//
// NVD is consulted first: any exact, rough or vendor/product match (see match)
// on one of its CPEs sets nvdMatch. Otherwise jvnMatch is set when a JVN
// CPE has the same part/vendor/product as uri. JVN CPEs are only considered
// when NVD lists nothing for that part/vendor/product, because NVD carries
// version information and JVN does not; this can produce false positives.
// CPEs that fail to parse are skipped. err is currently always nil.
func matchCpe(uri string, cve *models.CveDetail) (nvdMatch, jvnMatch bool, err error) {
	for _, entry := range cve.Nvds {
		for _, nvdCpe := range entry.Cpes {
			exact, rough, vendorProduct, matchErr := match(uri, nvdCpe.CpeBase)
			if matchErr != nil {
				log.Debugf("Failed to match: %s", matchErr)
				continue
			}
			if exact || rough || vendorProduct {
				return true, false, nil
			}
		}
	}

	if !cve.HasJvn() {
		return false, false, nil
	}

	for _, entry := range cve.Jvns {
		for _, jvnCpe := range entry.Cpes {
			// Defer to NVD whenever it covers the same part/vendor/product.
			if isCpeURIAlsoDefinedInNvd(jvnCpe.URI, cve.Nvds) {
				continue
			}
			same, sameErr := isSamePartVendorProduct(uri, jvnCpe.URI)
			if sameErr == nil && same {
				return false, true, nil
			}
		}
	}

	return false, false, nil
}

// isSuperORSubset reports whether source is a CPE name-matching superset or
// subset of target once the version attribute of both is treated as ANY.
//
// The comparison runs on shallow copies. WellFormedName is a map, so setting
// the version on the arguments themselves would silently change the
// caller's WFNs.
func isSuperORSubset(source, target common.WellFormedName) bool {
	// "ANY" is a fixed logical value, so NewLogicalValue is not expected to fail.
	anyval, _ := common.NewLogicalValue("ANY")

	clone := func(w common.WellFormedName) common.WellFormedName {
		c := make(common.WellFormedName, len(w))
		for k, v := range w {
			c[k] = v
		}
		return c
	}
	src, tgt := clone(source), clone(target)

	// Best effort, as before: assigning a logical value to the version
	// attribute is not expected to fail.
	_ = tgt.Set(common.AttributeVersion, anyval)
	_ = src.Set(common.AttributeVersion, anyval)

	if matching.IsSuperset(src, tgt) {
		log.Debugf("%s is superset of %s", src.String(), tgt.String())
		return true
	}
	if matching.IsSubset(src, tgt) {
		log.Debugf("%s is subset of %s", src.String(), tgt.String())
		return true
	}
	return false
}

// trimBSlash returns s with every backslash removed, e.g. turning the
// escaped WFN value `1\.2\.3` into `1.2.3`.
func trimBSlash(s string) string {
	return strings.ReplaceAll(s, `\`, "")
}

// filterCveDetailByCpeURI narrows d, in place, to the NVD and JVN entries
// that match uri, and records on each kept entry how it matched.
//
// An NVD entry is kept when any of its CPEs matches (see match). A JVN entry
// is kept when one of its CPEs shares part/vendor/product with uri and the
// original NVD data has no CPE for that part/vendor/product. A nil d is a
// no-op. The returned error is currently always nil.
func filterCveDetailByCpeURI(uri string, d *models.CveDetail) error {
	if d == nil {
		return nil
	}

	origNvds := append([]models.Nvd{}, d.Nvds...)
	d.Nvds = []models.Nvd{}
	for _, entry := range origNvds {
		for _, c := range entry.Cpes {
			exact, rough, vendorProduct, err := match(uri, c.CpeBase)
			if err != nil {
				log.Debugf("Failed to match. err: %s, uri: %s, CpeBase: %#v", err, uri, c.CpeBase)
			}
			switch {
			case exact:
				entry.DetectionMethod = models.NvdExactVersionMatch
			case rough:
				entry.DetectionMethod = models.NvdRoughVersionMatch
			case vendorProduct:
				entry.DetectionMethod = models.NvdVendorProductMatch
			default:
				continue
			}
			// First matching CPE decides; keep the entry once.
			d.Nvds = append(d.Nvds, entry)
			break
		}
	}

	origJvns := append([]models.Jvn{}, d.Jvns...)
	d.Jvns = []models.Jvn{}
	for _, entry := range origJvns {
		for _, c := range entry.Cpes {
			// NVD wins whenever it covers the same part/vendor/product,
			// because NVD carries version information and JVN does not.
			if isCpeURIAlsoDefinedInNvd(c.URI, origNvds) {
				continue
			}
			if same, err := isSamePartVendorProduct(uri, c.URI); err != nil || !same {
				continue
			}
			entry.DetectionMethod = models.JvnVendorProductMatch
			d.Jvns = append(d.Jvns, entry)
			break
		}
	}

	return nil
}

// isCpeURIAlsoDefinedInNvd reports whether any CPE listed in nvds shares
// part, vendor and product with cpeURI. CPEs that fail to parse are skipped.
func isCpeURIAlsoDefinedInNvd(cpeURI string, nvds []models.Nvd) bool {
	for i := range nvds {
		for _, c := range nvds[i].Cpes {
			if same, err := isSamePartVendorProduct(c.URI, cpeURI); err == nil && same {
				return true
			}
		}
	}
	return false
}
